{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "This example requires the following dependencies to be installed:\n",
    "pip install lightly[timm]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install lightly[timm]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lightly.data.collate import IJEPAMaskCollator\n",
    "from lightly.models import utils\n",
    "from lightly.models.modules.ijepa import IJEPABackbone, IJEPAPredictor\n",
    "from lightly.transforms.ijepa_transform import IJEPATransform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "class IJEPA(nn.Module):\n",
    "    def __init__(self, vit_encoder, vit_predictor, momentum_scheduler):\n",
    "        super().__init__()\n",
    "        self.encoder = IJEPABackbone.from_vit(vit_encoder)\n",
    "        self.predictor = IJEPAPredictor.from_vit_encoder(\n",
    "            vit_predictor.encoder,\n",
    "            (vit_predictor.image_size // vit_predictor.patch_size) ** 2,\n",
    "        )\n",
    "        self.target_encoder = copy.deepcopy(self.encoder)\n",
    "        self.momentum_scheduler = momentum_scheduler\n",
    "\n",
    "    def forward_target(self, imgs, masks_enc, masks_pred):\n",
    "        with torch.no_grad():\n",
    "            h = self.target_encoder(imgs)\n",
    "            h = F.layer_norm(h, (h.size(-1),))  # normalize over feature-dim\n",
    "            B = len(h)\n",
    "            # -- create targets (masked regions of h)\n",
    "            h = utils.apply_masks(h, masks_pred)\n",
    "            h = utils.repeat_interleave_batch(h, B, repeat=len(masks_enc))\n",
    "            return h\n",
    "\n",
    "    def forward_context(self, imgs, masks_enc, masks_pred):\n",
    "        z = self.encoder(imgs, masks_enc)\n",
    "        z = self.predictor(z, masks_enc, masks_pred)\n",
    "        return z\n",
    "\n",
    "    def forward(self, imgs, masks_enc, masks_pred):\n",
    "        z = self.forward_context(imgs, masks_enc, masks_pred)\n",
    "        h = self.forward_target(imgs, masks_enc, masks_pred)\n",
    "        return z, h\n",
    "\n",
    "    def update_target_encoder(\n",
    "        self,\n",
    "    ):\n",
    "        with torch.no_grad():\n",
    "            m = next(self.momentum_scheduler)\n",
    "            for param_q, param_k in zip(\n",
    "                self.encoder.parameters(), self.target_encoder.parameters()\n",
    "            ):\n",
    "                param_k.data.mul_(m).add_((1.0 - m) * param_q.detach().data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "collator = IJEPAMaskCollator(\n",
    "    input_size=(224, 224),\n",
    "    patch_size=32,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = IJEPATransform()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "source": [
    "we ignore object detection annotations by setting target_transform to return 0\n",
    "or create a dataset from a folder containing images or videos:\n",
    "dataset = LightlyDataset(\"path/to/folder\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def target_transform(t):\n",
    "    return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = torchvision.datasets.VOCDetection(\n",
    "    \"datasets/pascal_voc\",\n",
    "    download=True,\n",
    "    transform=transform,\n",
    "    target_transform=target_transform,\n",
    ")\n",
    "data_loader = torch.utils.data.DataLoader(\n",
    "    dataset, collate_fn=collator, batch_size=10, persistent_workers=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "ema = (0.996, 1.0)\n",
    "ipe_scale = 1.0\n",
    "ipe = len(data_loader)\n",
    "num_epochs = 10\n",
    "momentum_scheduler = (\n",
    "    ema[0] + i * (ema[1] - ema[0]) / (ipe * num_epochs * ipe_scale)\n",
    "    for i in range(int(ipe * num_epochs * ipe_scale) + 1)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12",
   "metadata": {},
   "outputs": [],
   "source": [
    "vit_for_predictor = torchvision.models.vit_b_32(pretrained=False)\n",
    "vit_for_embedder = torchvision.models.vit_b_32(pretrained=False)\n",
    "model = IJEPA(vit_for_embedder, vit_for_predictor, momentum_scheduler)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13",
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.SmoothL1Loss()\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4)\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Starting Training\")\n",
    "for epoch in range(num_epochs):\n",
    "    total_loss = 0\n",
    "    for udata, masks_enc, masks_pred in tqdm(data_loader):\n",
    "\n",
    "        def load_imgs():\n",
    "            # -- unsupervised imgs\n",
    "            imgs = udata[0].to(device, non_blocking=True)\n",
    "            masks_1 = [u.to(device, non_blocking=True) for u in masks_enc]\n",
    "            masks_2 = [u.to(device, non_blocking=True) for u in masks_pred]\n",
    "            return (imgs, masks_1, masks_2)\n",
    "\n",
    "        imgs, masks_enc, masks_pred = load_imgs()\n",
    "        z, h = model(imgs, masks_enc, masks_pred)\n",
    "        loss = criterion(z, h)\n",
    "        total_loss += loss.detach()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "        model.update_target_encoder()\n",
    "\n",
    "    avg_loss = total_loss / len(data_loader)\n",
    "    print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
